Skip to content

Adding SchemaTrackingMixin#1109

Merged
marcromeyn merged 7 commits intomainfrom
torch/track-schema
May 29, 2023
Merged

Adding SchemaTrackingMixin#1109
marcromeyn merged 7 commits intomainfrom
torch/track-schema

Conversation

@marcromeyn
Copy link
Contributor

@marcromeyn marcromeyn commented May 23, 2023

Goals ⚽

This PR introduces SchemaTrackingMixin , a mixin class for PyTorch modules to track the output shapes and dtypes of the forward pass. This is used in order to automatically generate the output-schema. It registers a hook to capture this information and provides methods to access the output schema.

@github-actions
Copy link

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1109

@marcromeyn marcromeyn added this to the Merlin 23.06 milestone May 23, 2023
@marcromeyn marcromeyn self-assigned this May 23, 2023
module._output_shapes[key] = value.shape
module._output_dtypes[key] = value.dtype
else:
module._output_shapes["output"] = output.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this "output" key is for the case where a module outputs a single tensor? Would something that uses this output schema depend on the name of this column name elsewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, not sure actually. Maybe we should extract it in a constant somewhere?

@marcromeyn marcromeyn added the enhancement New feature or request label May 26, 2023
@marcromeyn marcromeyn marked this pull request as ready for review May 29, 2023 11:59
@marcromeyn marcromeyn merged commit b6d6645 into main May 29, 2023
@marcromeyn marcromeyn deleted the torch/track-schema branch May 29, 2023 12:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area/pytorch enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants